Skip to main content
Version: 1.0.8

Causal Inference

DoubleMLEstimator

from synapse.ml.causal import *
from pyspark.ml.classification import LogisticRegression
from pyspark.sql.types import StructType, StructField, DoubleType, IntegerType, BooleanType

# Schema for the toy dataset: a boolean treatment flag, a boolean outcome
# flag, and three numeric confounder columns.
confounder_fields = [StructField(col, DoubleType()) for col in ("col2", "col3", "col4")]
schema = StructType(
    [
        StructField("Treatment", BooleanType()),
        StructField("Outcome", BooleanType()),
    ]
    + confounder_fields
)


# Twelve synthetic rows of (Treatment, Outcome, col2, col3, col4) used to
# demonstrate the estimator; requires an active `spark` session.
rows = [
    (False, True, 0.30, 0.66, 0.2),
    (True, False, 0.38, 0.53, 1.5),
    (False, True, 0.68, 0.98, 3.2),
    (True, False, 0.15, 0.32, 6.6),
    (False, True, 0.50, 0.65, 2.8),
    (True, True, 0.40, 0.54, 3.7),
    (False, True, 0.78, 0.97, 8.1),
    (True, False, 0.12, 0.32, 10.2),
    (False, True, 0.35, 0.63, 1.8),
    (True, False, 0.45, 0.57, 4.3),
    (False, True, 0.75, 0.97, 7.2),
    (True, True, 0.16, 0.32, 11.7),
]
df = spark.createDataFrame(rows, schema)

# Configure Double ML: a logistic regression nuisance model for both the
# treatment assignment and the outcome, with 20 sampling iterations.
dml = (
    DoubleMLEstimator()
    .setTreatmentCol("Treatment")
    .setTreatmentModel(LogisticRegression())
    .setOutcomeCol("Outcome")
    .setOutcomeModel(LogisticRegression())
    .setMaxIter(20)
)

# Fit the model, then read off the average treatment effect and its
# confidence interval.
dmlModel = dml.fit(df)
dmlModel.getAvgTreatmentEffect()
dmlModel.getConfidenceInterval()
Python API: DoubleMLEstimator · Scala API: DoubleMLEstimator · Source: DoubleMLEstimator